application: targeted data collection

knowing what we know, where and when should we plan to next collect data?

planning the next test

survival analysis

Code
library(cmdstanr)

# Compile the Stan survival model (cmdstanr caches the executable) and
# print its auto-formatted source.
survival_model <- cmdstan_model(stan_file = "survival.stan")
survival_model$format()
data {
  int<lower=0> n_meas; // number of observations
  vector<lower=0>[n_meas] obs_time; // time of observation
  vector<lower=0>[n_meas] fail_lb; // lower bound of failure time
  vector<lower=0>[n_meas] fail_ub; // upper bound of failure time
  
  array[n_meas] int<lower=0, upper=1> fail_status; // if a failure has occurred, we have interval-censored data
  
  int<lower=0> n_pred; // number of predictions
  vector<lower=0>[n_pred] pred_time; // time of prediction
}
parameters {
  real<lower=0> scale; // scale parameter
  real<lower=0> shape; // shape parameter
}
model {
  //priors
  scale ~ normal(8, 3);
  shape ~ normal(6, 3);
  
  //likelihood
  for (n in 1 : n_meas) {
    if (fail_status[n] == 0) {
      // right-censored: no failure seen by obs_time, contributes P(T > obs_time)
      target += log1m(loglogistic_cdf(obs_time[n] | scale, shape));
    } else {
      // interval-censored: failure in (fail_lb, fail_ub), contributes P(lb < T < ub)
      target += log(loglogistic_cdf(fail_ub[n] | scale, shape)
                    - loglogistic_cdf(fail_lb[n] | scale, shape));
    }
  }
}
generated quantities {
  vector[n_meas] log_lik; // pointwise log-likelihood (e.g. for loo / waic)
  vector[n_pred] p_fail_pred; // posterior predictive P(failure by pred_time)
  
  for (n in 1 : n_meas) {
    // BUG FIX: this branch was inverted relative to the model block
    // (previously `fail_status[n] == 1` received the right-censored term),
    // so every log_lik entry used the wrong likelihood contribution.
    if (fail_status[n] == 0) {
      log_lik[n] = log1m(loglogistic_cdf(obs_time[n] | scale, shape));
    } else {
      log_lik[n] = log(loglogistic_cdf(fail_ub[n] | scale, shape)
                       - loglogistic_cdf(fail_lb[n] | scale, shape));
    }
  }
  
  for (n in 1 : n_pred) {
    p_fail_pred[n] = loglogistic_cdf(pred_time[n] | scale, shape);
  }
}
Code
import cmdstanpy

# Compile the Stan survival model (cmdstanpy reuses a cached executable if one exists).
survival_model = cmdstanpy.CmdStanModel(stan_file = "survival.stan")
INFO:cmdstanpy:found newer exe file, not recompiling
Code
# Retrieve the model's Stan source as a string.
stan_code = survival_model.code()

from pygments import highlight
from pygments.lexers import StanLexer
from pygments.formatters import NullFormatter

# NullFormatter tokenizes without adding markup, so this prints plain text.
formatted_stan_code = highlight(stan_code, StanLexer(), NullFormatter())

print(formatted_stan_code)
data {
  int <lower = 0> n_meas;                   // number of observations
  vector <lower = 0> [n_meas] obs_time;     // time of observation
  vector <lower = 0> [n_meas] fail_lb;      // lower bound of failure time
  vector <lower = 0> [n_meas] fail_ub;      // upper bound of failure time

  array [n_meas] int<lower = 0, upper = 1> fail_status; // if a failure has occurred, we have interval-censored data

  int <lower = 0> n_pred;                   // number of predictions
  vector <lower = 0> [n_pred] pred_time;    // time of prediction
}

parameters {
  real <lower = 0> scale; // scale parameter
  real <lower = 0> shape; // shape parameter
}

model{
    //priors
    scale ~ normal(8, 3);
    shape ~ normal(6, 3);

    //likelihood
    for(n in 1:n_meas){
        if(fail_status[n] == 0){
            // right-censored: P(T > obs_time)
            target += log1m(loglogistic_cdf(obs_time[n] | scale, shape));
        } else {
            // interval-censored: P(fail_lb < T < fail_ub)
            target += log(
                          loglogistic_cdf(fail_ub[n] | scale, shape) - 
                          loglogistic_cdf(fail_lb[n] | scale, shape)
                        );
        }
    }
}

generated quantities {
  vector [n_meas] log_lik;      // pointwise log-likelihood
  vector [n_pred] p_fail_pred;  // posterior predictive P(failure by pred_time)

  for(n in 1:n_meas){
    // BUG FIX: the branch was inverted relative to the model block
    // (previously `fail_status[n] == 1` got the right-censored term).
    if(fail_status[n] == 0){
      log_lik[n] = log1m(loglogistic_cdf(obs_time[n] | scale, shape));
    } else {
      log_lik[n] = log(
                        loglogistic_cdf(fail_ub[n] | scale, shape) - 
                        loglogistic_cdf(fail_lb[n] | scale, shape)
                      );
    }
  }

  for(n in 1:n_pred){
    p_fail_pred[n] = loglogistic_cdf(pred_time[n] | scale, shape);
  }
  
}
Code
using Turing
using LogExpFunctions: log1mexp

include("../../data/LogLogisticDistribution.jl")
LogLogisticDistribution (generic function with 1 method)
Code

@model function loglogistic_survival(
    obs_time::Vector{Float64},     # time of observation
    fail_lb::Vector{Float64},      # lower bound of failure time
    fail_ub::Vector{Float64},      # upper bound of failure time
    fail_status::Vector{Int}       # 0 if right-censored, 1 if interval-censored
)
    # Priors: positive-truncated normals on the log-logistic parameters
    scale ~ truncated(Normal(8, 3); lower = 0)
    shape ~ truncated(Normal(6, 3); lower = 0)

    # Log-logistic lifetime distribution under the current parameter draw
    dist = LogLogisticDistribution(scale, shape)

    # Likelihood: each unit contributes a censored-data term
    for idx in eachindex(obs_time)
        if fail_status[idx] == 0
            # Right-censored: survived past the observation time, P(T > t_obs)
            Turing.@addlogprob! log(survival(dist, obs_time[idx]))
        else
            # Interval-censored: failed between the bounds, P(lb < T < ub)
            Turing.@addlogprob! log(cdf(dist, fail_ub[idx]) - cdf(dist, fail_lb[idx]))
        end
    end
end
loglogistic_survival (generic function with 2 methods)

survival analysis

Code
library(tidyverse)

# Failure observations; fail_ub is Inf for units with no observed failure.
failure_data <- read_csv("../../data/failures.csv")

model_data <- list(
  n_meas = nrow(failure_data),
  obs_time = rep(12, nrow(failure_data)),  # every unit inspected at t = 12
  fail_lb = failure_data$fail_lb,
  fail_ub = failure_data$fail_ub,
  fail_status = is.finite(failure_data$fail_ub) |> as.integer(),  # 1 = interval-censored, 0 = right-censored
  n_pred = 100,
  pred_time = seq(from = 0, to = 20, length.out = 100)  # prediction time grid
)

survival_fit <- survival_model$sample(
  data = model_data,
  chains = 4,
  parallel_chains = parallel::detectCores(),
  seed = 231123,
  iter_warmup = 2000,
  iter_sampling = 2000
)
Running MCMC with 4 chains, at most 16 in parallel...

Chain 1 Iteration:    1 / 4000 [  0%]  (Warmup) 
Chain 1 Iteration:  100 / 4000 [  2%]  (Warmup) 
Chain 1 Iteration:  200 / 4000 [  5%]  (Warmup) 
Chain 1 Iteration:  300 / 4000 [  7%]  (Warmup) 
Chain 1 Iteration:  400 / 4000 [ 10%]  (Warmup) 
Chain 1 Iteration:  500 / 4000 [ 12%]  (Warmup) 
Chain 1 Iteration:  600 / 4000 [ 15%]  (Warmup) 
Chain 1 Iteration:  700 / 4000 [ 17%]  (Warmup) 
Chain 1 Iteration:  800 / 4000 [ 20%]  (Warmup) 
Chain 1 Iteration:  900 / 4000 [ 22%]  (Warmup) 
Chain 1 Iteration: 1000 / 4000 [ 25%]  (Warmup) 
Chain 1 Iteration: 1100 / 4000 [ 27%]  (Warmup) 
Chain 1 Iteration: 1200 / 4000 [ 30%]  (Warmup) 
Chain 1 Iteration: 1300 / 4000 [ 32%]  (Warmup) 
Chain 1 Iteration: 1400 / 4000 [ 35%]  (Warmup) 
Chain 1 Iteration: 1500 / 4000 [ 37%]  (Warmup) 
Chain 1 Iteration: 1600 / 4000 [ 40%]  (Warmup) 
Chain 1 Iteration: 1700 / 4000 [ 42%]  (Warmup) 
Chain 1 Iteration: 1800 / 4000 [ 45%]  (Warmup) 
Chain 1 Iteration: 1900 / 4000 [ 47%]  (Warmup) 
Chain 1 Iteration: 2000 / 4000 [ 50%]  (Warmup) 
Chain 1 Iteration: 2001 / 4000 [ 50%]  (Sampling) 
Chain 1 Iteration: 2100 / 4000 [ 52%]  (Sampling) 
Chain 1 Iteration: 2200 / 4000 [ 55%]  (Sampling) 
Chain 1 Iteration: 2300 / 4000 [ 57%]  (Sampling) 
Chain 1 Iteration: 2400 / 4000 [ 60%]  (Sampling) 
Chain 1 Iteration: 2500 / 4000 [ 62%]  (Sampling) 
Chain 1 Iteration: 2600 / 4000 [ 65%]  (Sampling) 
Chain 1 Iteration: 2700 / 4000 [ 67%]  (Sampling) 
Chain 1 Iteration: 2800 / 4000 [ 70%]  (Sampling) 
Chain 1 Iteration: 2900 / 4000 [ 72%]  (Sampling) 
Chain 2 Iteration:    1 / 4000 [  0%]  (Warmup) 
Chain 2 Iteration:  100 / 4000 [  2%]  (Warmup) 
Chain 2 Iteration:  200 / 4000 [  5%]  (Warmup) 
Chain 2 Iteration:  300 / 4000 [  7%]  (Warmup) 
Chain 2 Iteration:  400 / 4000 [ 10%]  (Warmup) 
Chain 2 Iteration:  500 / 4000 [ 12%]  (Warmup) 
Chain 2 Iteration:  600 / 4000 [ 15%]  (Warmup) 
Chain 2 Iteration:  700 / 4000 [ 17%]  (Warmup) 
Chain 2 Iteration:  800 / 4000 [ 20%]  (Warmup) 
Chain 2 Iteration:  900 / 4000 [ 22%]  (Warmup) 
Chain 2 Iteration: 1000 / 4000 [ 25%]  (Warmup) 
Chain 2 Iteration: 1100 / 4000 [ 27%]  (Warmup) 
Chain 2 Iteration: 1200 / 4000 [ 30%]  (Warmup) 
Chain 2 Iteration: 1300 / 4000 [ 32%]  (Warmup) 
Chain 2 Iteration: 1400 / 4000 [ 35%]  (Warmup) 
Chain 2 Iteration: 1500 / 4000 [ 37%]  (Warmup) 
Chain 2 Iteration: 1600 / 4000 [ 40%]  (Warmup) 
Chain 2 Iteration: 1700 / 4000 [ 42%]  (Warmup) 
Chain 2 Iteration: 1800 / 4000 [ 45%]  (Warmup) 
Chain 2 Iteration: 1900 / 4000 [ 47%]  (Warmup) 
Chain 2 Iteration: 2000 / 4000 [ 50%]  (Warmup) 
Chain 2 Iteration: 2001 / 4000 [ 50%]  (Sampling) 
Chain 2 Iteration: 2100 / 4000 [ 52%]  (Sampling) 
Chain 2 Iteration: 2200 / 4000 [ 55%]  (Sampling) 
Chain 2 Iteration: 2300 / 4000 [ 57%]  (Sampling) 
Chain 2 Iteration: 2400 / 4000 [ 60%]  (Sampling) 
Chain 2 Iteration: 2500 / 4000 [ 62%]  (Sampling) 
Chain 2 Iteration: 2600 / 4000 [ 65%]  (Sampling) 
Chain 2 Iteration: 2700 / 4000 [ 67%]  (Sampling) 
Chain 2 Iteration: 2800 / 4000 [ 70%]  (Sampling) 
Chain 2 Iteration: 2900 / 4000 [ 72%]  (Sampling) 
Chain 2 Iteration: 3000 / 4000 [ 75%]  (Sampling) 
Chain 2 Iteration: 3100 / 4000 [ 77%]  (Sampling) 
Chain 2 Iteration: 3200 / 4000 [ 80%]  (Sampling) 
Chain 2 Iteration: 3300 / 4000 [ 82%]  (Sampling) 
Chain 2 Iteration: 3400 / 4000 [ 85%]  (Sampling) 
Chain 2 Iteration: 3500 / 4000 [ 87%]  (Sampling) 
Chain 2 Iteration: 3600 / 4000 [ 90%]  (Sampling) 
Chain 3 Iteration:    1 / 4000 [  0%]  (Warmup) 
Chain 3 Iteration:  100 / 4000 [  2%]  (Warmup) 
Chain 3 Iteration:  200 / 4000 [  5%]  (Warmup) 
Chain 3 Iteration:  300 / 4000 [  7%]  (Warmup) 
Chain 3 Iteration:  400 / 4000 [ 10%]  (Warmup) 
Chain 3 Iteration:  500 / 4000 [ 12%]  (Warmup) 
Chain 3 Iteration:  600 / 4000 [ 15%]  (Warmup) 
Chain 3 Iteration:  700 / 4000 [ 17%]  (Warmup) 
Chain 3 Iteration:  800 / 4000 [ 20%]  (Warmup) 
Chain 3 Iteration:  900 / 4000 [ 22%]  (Warmup) 
Chain 3 Iteration: 1000 / 4000 [ 25%]  (Warmup) 
Chain 3 Iteration: 1100 / 4000 [ 27%]  (Warmup) 
Chain 3 Iteration: 1200 / 4000 [ 30%]  (Warmup) 
Chain 3 Iteration: 1300 / 4000 [ 32%]  (Warmup) 
Chain 3 Iteration: 1400 / 4000 [ 35%]  (Warmup) 
Chain 3 Iteration: 1500 / 4000 [ 37%]  (Warmup) 
Chain 3 Iteration: 1600 / 4000 [ 40%]  (Warmup) 
Chain 3 Iteration: 1700 / 4000 [ 42%]  (Warmup) 
Chain 3 Iteration: 1800 / 4000 [ 45%]  (Warmup) 
Chain 3 Iteration: 1900 / 4000 [ 47%]  (Warmup) 
Chain 3 Iteration: 2000 / 4000 [ 50%]  (Warmup) 
Chain 3 Iteration: 2001 / 4000 [ 50%]  (Sampling) 
Chain 3 Iteration: 2100 / 4000 [ 52%]  (Sampling) 
Chain 3 Iteration: 2200 / 4000 [ 55%]  (Sampling) 
Chain 3 Iteration: 2300 / 4000 [ 57%]  (Sampling) 
Chain 3 Iteration: 2400 / 4000 [ 60%]  (Sampling) 
Chain 3 Iteration: 2500 / 4000 [ 62%]  (Sampling) 
Chain 3 Iteration: 2600 / 4000 [ 65%]  (Sampling) 
Chain 3 Iteration: 2700 / 4000 [ 67%]  (Sampling) 
Chain 3 Iteration: 2800 / 4000 [ 70%]  (Sampling) 
Chain 3 Iteration: 2900 / 4000 [ 72%]  (Sampling) 
Chain 3 Iteration: 3000 / 4000 [ 75%]  (Sampling) 
Chain 3 Iteration: 3100 / 4000 [ 77%]  (Sampling) 
Chain 3 Iteration: 3200 / 4000 [ 80%]  (Sampling) 
Chain 3 Iteration: 3300 / 4000 [ 82%]  (Sampling) 
Chain 3 Iteration: 3400 / 4000 [ 85%]  (Sampling) 
Chain 3 Iteration: 3500 / 4000 [ 87%]  (Sampling) 
Chain 3 Iteration: 3600 / 4000 [ 90%]  (Sampling) 
Chain 3 Iteration: 3700 / 4000 [ 92%]  (Sampling) 
Chain 3 Iteration: 3800 / 4000 [ 95%]  (Sampling) 
Chain 4 Iteration:    1 / 4000 [  0%]  (Warmup) 
Chain 4 Iteration:  100 / 4000 [  2%]  (Warmup) 
Chain 4 Iteration:  200 / 4000 [  5%]  (Warmup) 
Chain 4 Iteration:  300 / 4000 [  7%]  (Warmup) 
Chain 4 Iteration:  400 / 4000 [ 10%]  (Warmup) 
Chain 4 Iteration:  500 / 4000 [ 12%]  (Warmup) 
Chain 4 Iteration:  600 / 4000 [ 15%]  (Warmup) 
Chain 4 Iteration:  700 / 4000 [ 17%]  (Warmup) 
Chain 4 Iteration:  800 / 4000 [ 20%]  (Warmup) 
Chain 4 Iteration:  900 / 4000 [ 22%]  (Warmup) 
Chain 4 Iteration: 1000 / 4000 [ 25%]  (Warmup) 
Chain 4 Iteration: 1100 / 4000 [ 27%]  (Warmup) 
Chain 4 Iteration: 1200 / 4000 [ 30%]  (Warmup) 
Chain 4 Iteration: 1300 / 4000 [ 32%]  (Warmup) 
Chain 4 Iteration: 1400 / 4000 [ 35%]  (Warmup) 
Chain 4 Iteration: 1500 / 4000 [ 37%]  (Warmup) 
Chain 4 Iteration: 1600 / 4000 [ 40%]  (Warmup) 
Chain 4 Iteration: 1700 / 4000 [ 42%]  (Warmup) 
Chain 4 Iteration: 1800 / 4000 [ 45%]  (Warmup) 
Chain 4 Iteration: 1900 / 4000 [ 47%]  (Warmup) 
Chain 4 Iteration: 2000 / 4000 [ 50%]  (Warmup) 
Chain 4 Iteration: 2001 / 4000 [ 50%]  (Sampling) 
Chain 4 Iteration: 2100 / 4000 [ 52%]  (Sampling) 
Chain 4 Iteration: 2200 / 4000 [ 55%]  (Sampling) 
Chain 4 Iteration: 2300 / 4000 [ 57%]  (Sampling) 
Chain 4 Iteration: 2400 / 4000 [ 60%]  (Sampling) 
Chain 4 Iteration: 2500 / 4000 [ 62%]  (Sampling) 
Chain 4 Iteration: 2600 / 4000 [ 65%]  (Sampling) 
Chain 4 Iteration: 2700 / 4000 [ 67%]  (Sampling) 
Chain 4 Iteration: 2800 / 4000 [ 70%]  (Sampling) 
Chain 4 Iteration: 2900 / 4000 [ 72%]  (Sampling) 
Chain 4 Iteration: 3000 / 4000 [ 75%]  (Sampling) 
Chain 4 Iteration: 3100 / 4000 [ 77%]  (Sampling) 
Chain 4 Iteration: 3200 / 4000 [ 80%]  (Sampling) 
Chain 4 Iteration: 3300 / 4000 [ 82%]  (Sampling) 
Chain 4 Iteration: 3400 / 4000 [ 85%]  (Sampling) 
Chain 4 Iteration: 3500 / 4000 [ 87%]  (Sampling) 
Chain 4 Iteration: 3600 / 4000 [ 90%]  (Sampling) 
Chain 4 Iteration: 3700 / 4000 [ 92%]  (Sampling) 
Chain 4 Iteration: 3800 / 4000 [ 95%]  (Sampling) 
Chain 4 Iteration: 3900 / 4000 [ 97%]  (Sampling) 
Chain 1 Iteration: 3000 / 4000 [ 75%]  (Sampling) 
Chain 1 Iteration: 3100 / 4000 [ 77%]  (Sampling) 
Chain 1 Iteration: 3200 / 4000 [ 80%]  (Sampling) 
Chain 1 Iteration: 3300 / 4000 [ 82%]  (Sampling) 
Chain 1 Iteration: 3400 / 4000 [ 85%]  (Sampling) 
Chain 1 Iteration: 3500 / 4000 [ 87%]  (Sampling) 
Chain 1 Iteration: 3600 / 4000 [ 90%]  (Sampling) 
Chain 1 Iteration: 3700 / 4000 [ 92%]  (Sampling) 
Chain 1 Iteration: 3800 / 4000 [ 95%]  (Sampling) 
Chain 1 Iteration: 3900 / 4000 [ 97%]  (Sampling) 
Chain 1 Iteration: 4000 / 4000 [100%]  (Sampling) 
Chain 1 finished in 0.2 seconds.
Chain 2 Iteration: 3700 / 4000 [ 92%]  (Sampling) 
Chain 2 Iteration: 3800 / 4000 [ 95%]  (Sampling) 
Chain 2 Iteration: 3900 / 4000 [ 97%]  (Sampling) 
Chain 2 Iteration: 4000 / 4000 [100%]  (Sampling) 
Chain 2 finished in 0.2 seconds.
Chain 3 Iteration: 3900 / 4000 [ 97%]  (Sampling) 
Chain 3 Iteration: 4000 / 4000 [100%]  (Sampling) 
Chain 3 finished in 0.2 seconds.
Chain 4 Iteration: 4000 / 4000 [100%]  (Sampling) 
Chain 4 finished in 0.2 seconds.

All 4 chains finished successfully.
Mean chain execution time: 0.2 seconds.
Total execution time: 0.3 seconds.
Code
survival_fit$summary()
# A tibble: 123 × 10
   variable     mean median    sd   mad     q5    q95  rhat ess_bulk ess_tail
   <chr>       <dbl>  <dbl> <dbl> <dbl>  <dbl>  <dbl> <dbl>    <dbl>    <dbl>
 1 lp__       -33.0  -32.7  1.04  0.747 -35.0  -32.0   1.00    3431.    4734.
 2 scale        9.48   9.45 0.693 0.678   8.38  10.6   1.00    5912.    4789.
 3 shape        5.62   5.57 1.09  1.10    3.93   7.50  1.00    5308.    4987.
 4 log_lik[1]  -1.59  -1.56 0.378 0.369  -2.27  -1.03  1.00    5668.    4719.
 5 log_lik[2]  -1.59  -1.56 0.378 0.369  -2.27  -1.03  1.00    5668.    4719.
 6 log_lik[3]  -1.59  -1.56 0.378 0.369  -2.27  -1.03  1.00    5668.    4719.
 7 log_lik[4]  -1.59  -1.56 0.378 0.369  -2.27  -1.03  1.00    5668.    4719.
 8 log_lik[5]  -1.59  -1.56 0.378 0.369  -2.27  -1.03  1.00    5668.    4719.
 9 log_lik[6]  -1.59  -1.56 0.378 0.369  -2.27  -1.03  1.00    5668.    4719.
10 log_lik[7]  -1.59  -1.56 0.378 0.369  -2.27  -1.03  1.00    5668.    4719.
# ℹ 113 more rows
Code
import polars as pl, numpy as np
import multiprocessing

# Failure observations; cast the bounds to floats so Inf is representable.
failure_data = pl.read_csv("../../data/failures.csv").with_columns([
    pl.col("fail_ub").cast(pl.Float64),
    pl.col("fail_lb").cast(pl.Float64)
])

# Define a large finite number to substitute for infinity.
# (CmdStan's data input cannot encode Inf.)
large_num = 1e10

# Convert the fail_ub array from the failure_data, and replace inf values.
fail_ub = failure_data["fail_ub"].to_numpy().copy()
fail_ub[~np.isfinite(fail_ub)] = large_num

# Prepare your model_data dictionary.
model_data = {
    "n_meas": failure_data.shape[0],
    "obs_time": [12] * failure_data.shape[0],  # every unit inspected at t = 12
    "fail_lb": failure_data["fail_lb"].to_numpy(),
    "fail_ub": fail_ub,
    "fail_status": (failure_data["fail_ub"].is_finite().cast(pl.Int64)).to_numpy(),  # 1 = interval-censored
    "n_pred": 100,
    "pred_time": np.linspace(start=0, stop=20, num=100)  # prediction time grid
}

survival_fit = survival_model.sample(
  data = model_data,
  chains = 4,
  parallel_chains = 1,  # NOTE(review): the R version uses all cores — confirm intended
  seed = 231123,
  iter_warmup = 2000,
  iter_sampling = 2000
)
                                                                                                                                                                                                                                                                                                                                

INFO:cmdstanpy:CmdStan start processing

chain 1 |          | 00:00 Status

chain 2 |          | 00:00 Status


chain 3 |          | 00:00 Status



chain 4 |          | 00:00 Status
chain 1 |######4   | 00:00 Iteration: 2400 / 4000 [ 60%]  (Sampling)

chain 2 |2         | 00:00 Status

chain 2 |#######3  | 00:00 Iteration: 2800 / 4000 [ 70%]  (Sampling)


chain 3 |2         | 00:00 Status


chain 3 |#######6  | 00:00 Iteration: 2900 / 4000 [ 72%]  (Sampling)



chain 4 |2         | 00:00 Status



chain 4 |#######3  | 00:00 Iteration: 2800 / 4000 [ 70%]  (Sampling)
chain 1 |##########| 00:00 Sampling completed                       

chain 2 |##########| 00:00 Sampling completed                       

chain 3 |##########| 00:00 Sampling completed                       

chain 4 |##########| 00:00 Sampling completed                       
INFO:cmdstanpy:CmdStan done processing.
Code
survival_fit.summary()
                   Mean     MCSE  StdDev     5%  ...   95%   N_Eff  N_Eff/s  R_hat
name                                             ...                              
lp__             -33.00  0.01800   1.000 -35.00  ... -32.0  3500.0   6500.0    1.0
scale              9.50  0.00910   0.690   8.40  ...  11.0  5800.0  11000.0    1.0
shape              5.60  0.01500   1.100   3.90  ...   7.5  5400.0   9900.0    1.0
log_lik[1]        -1.60  0.00500   0.380  -2.30  ...  -1.0  5750.0  10590.0    1.0
log_lik[2]        -1.60  0.00500   0.380  -2.30  ...  -1.0  5750.0  10590.0    1.0
...                 ...      ...     ...    ...  ...   ...     ...      ...    ...
p_fail_pred[96]    0.97  0.00033   0.022   0.93  ...   1.0  4457.0   8208.0    1.0
p_fail_pred[97]    0.98  0.00032   0.021   0.93  ...   1.0  4452.0   8199.0    1.0
p_fail_pred[98]    0.98  0.00031   0.021   0.94  ...   1.0  4448.0   8191.0    1.0
p_fail_pred[99]    0.98  0.00030   0.020   0.94  ...   1.0  4444.0   8184.0    1.0
p_fail_pred[100]   0.98  0.00029   0.019   0.94  ...   1.0  4441.0   8178.0    1.0

[123 rows x 9 columns]
Code
using CSV, DataFrames

failure_data = CSV.read("../../data/failures.csv", DataFrame)

survival_fit = loglogistic_survival(
    repeat([12.0], nrow(failure_data)),
    failure_data.fail_lb,
    failure_data.fail_ub,
    isfinite.(failure_data.fail_ub) |> x -> Int.(x)
) |> model -> sample(MersenneTwister(231123), model, NUTS(), MCMCThreads(), 2000, 4)
Chains MCMC chain (2000×14×4 Array{Float64, 3}):

Iterations        = 1001:1:3000
Number of chains  = 4
Samples per chain = 2000
Wall duration     = 9.57 seconds
Compute duration  = 7.51 seconds
parameters        = scale, shape
internals         = lp, n_steps, is_accept, acceptance_rate, log_density, hamiltonian_energy, hamiltonian_energy_error, max_hamiltonian_energy_error, tree_depth, numerical_error, step_size, nom_step_size

Summary Statistics
  parameters      mean       std      mcse    ess_bulk    ess_tail      rhat   ⋯
      Symbol   Float64   Float64   Float64     Float64     Float64   Float64   ⋯

       scale    9.4688    0.6886    0.0090   5854.0791   4655.8635    1.0006   ⋯
       shape    5.6014    1.0963    0.0143   5854.6561   5483.7428    1.0012   ⋯
                                                                1 column omitted

Quantiles
  parameters      2.5%     25.0%     50.0%     75.0%     97.5%
      Symbol   Float64   Float64   Float64   Float64   Float64

       scale    8.1597    9.0166    9.4565    9.9064   10.9001
       shape    3.6211    4.8263    5.5506    6.3210    7.8421

expected information gain

expected information gain

can be computationally intensive

expected information gain

  • quantify uncertainty in posterior predictions
  • identify prospective data collection options
  • generate all possible outcome scenarios
    • here (helpfully): failure or no failure
  • for each outcome:
    • simulate the data collection and re-fit the model
    • quantify uncertainty in the new posterior predictions
    • find the difference (reduction in uncertainty with the new data)
    • weight the reduction by the probability of the outcome
  • compare the expected “information gain” to rank order data collection options

measures of uncertainty

  • entropy?
  • log-likelihood?
  • kernel density estimation?
  • variance?
Code
post_pred |> head()
# A tibble: 6 × 5
  Parameter      Chain Iteration value  time
  <chr>          <int>     <int> <dbl> <dbl>
1 p_fail_pred[1]     1         1     0     0
2 p_fail_pred[1]     1         2     0     0
3 p_fail_pred[1]     1         3     0     0
4 p_fail_pred[1]     1         4     0     0
5 p_fail_pred[1]     1         5     0     0
6 p_fail_pred[1]     1         6     0     0
Code
estimate_uncertainty <- function(posterior = post_pred) {
  # Summarise posterior predictive spread at each prediction time,
  # using the variance of the sampled values as the uncertainty measure.
  grouped_draws <- group_by(posterior, time)
  summarise(grouped_draws, uncertainty_base = var(value))
}

estimate_uncertainty() |> head()
# A tibble: 6 × 2
   time uncertainty_base
  <dbl>            <dbl>
1 0             0       
2 0.202         4.57e-12
3 0.404         1.39e-10
4 0.606         1.16e- 9
5 0.808         5.65e- 9
6 1.01          2.05e- 8
Code
post_pred.head()
shape: (5, 5)
Chain Iteration Parameter value time
i64 i64 str f64 f64
1 1 "p_fail_pred[1]" 0.0 0.0
1 2 "p_fail_pred[1]" 0.0 0.0
1 3 "p_fail_pred[1]" 0.0 0.0
1 4 "p_fail_pred[1]" 0.0 0.0
1 5 "p_fail_pred[1]" 0.0 0.0
Code
def estimate_uncertainty(posterior=post_pred):
    """Variance of the posterior predictive draws at each prediction time,
    returned sorted by time with column name `uncertainty`."""
    by_time = posterior.group_by("time")
    summarised = by_time.agg(pl.col("value").var().alias("uncertainty"))
    return summarised.sort("time")

estimate_uncertainty().head()
shape: (5, 2)
time uncertainty
f64 f64
0.0 0.0
0.20202 4.5745e-12
0.40404 1.3886e-10
0.606061 1.1557e-9
0.808081 5.6478e-9

expected information gain

Code
estimate_information_gain <- function(proposed_time) {
  # Expected information gain from one additional inspection at `proposed_time`:
  # simulate both possible outcomes (failure / no failure), re-fit the model for
  # each, and compute the probability-weighted reduction in posterior-predictive
  # variance relative to the current fit.
  #
  # Relies on globals: model_data, survival_model, post_pred, pred_params,
  # and estimate_uncertainty().

  # Two independent copies of the current data, one per hypothetical outcome.
  # (Replaces the obscure chained right-assignment
  # `fail_data <- model_data -> no_fail_data` with two explicit assignments.)
  fail_data <- model_data
  no_fail_data <- model_data

  # case A: we observe a failure -> interval-censored in (t - 1.5, t]
  fail_data$n_meas <- fail_data$n_meas + 1
  fail_data$obs_time <- c(fail_data$obs_time, proposed_time)
  fail_data$fail_lb <- c(fail_data$fail_lb, proposed_time - 1.5)
  fail_data$fail_ub <- c(fail_data$fail_ub, proposed_time)
  fail_data$fail_status <- c(fail_data$fail_status, 1)

  # case B: we do not observe a failure -> right-censored at t
  no_fail_data$n_meas <- no_fail_data$n_meas + 1
  no_fail_data$obs_time <- c(no_fail_data$obs_time, proposed_time)
  no_fail_data$fail_lb <- c(no_fail_data$fail_lb, proposed_time)
  no_fail_data$fail_ub <- c(no_fail_data$fail_ub, Inf)
  no_fail_data$fail_status <- c(no_fail_data$fail_status, 0)

  # re-fitting our models for each possible outcome
  fail_fit <- survival_model$sample(
    data = fail_data,
    chains = 4,
    parallel_chains = parallel::detectCores(),
    seed = 231123,
    iter_warmup = 2000,
    iter_sampling = 2000
  )

  no_fail_fit <- survival_model$sample(
    data = no_fail_data,
    chains = 4,
    parallel_chains = parallel::detectCores(),
    seed = 231123,
    iter_warmup = 2000,
    iter_sampling = 2000
  )

  # quantify uncertainty in the new predictions
  base_uncertainties <- estimate_uncertainty()

  fail_uncertainties <- fail_fit |>
    DomDF::tidy_mcmc_draws(params = pred_params) |>
    mutate(time = rep(x = model_data$pred_time, 
           each = fail_fit$metadata()$iter_sampling * length(fail_fit$metadata()$id))) |>
    estimate_uncertainty() |> rename(uncertainty_fail = uncertainty_base)

  no_fail_uncertainties <- no_fail_fit |>
    DomDF::tidy_mcmc_draws(params = pred_params) |>
    mutate(time = rep(x = model_data$pred_time, 
           each = no_fail_fit$metadata()$iter_sampling * length(no_fail_fit$metadata()$id))) |>
    estimate_uncertainty() |> rename(uncertainty_no_fail = uncertainty_base)

  # prior predictive probability of failure at (the grid point nearest to)
  # the proposed inspection time
  p_fail <- post_pred |>
    filter(abs(time - proposed_time) == min(abs(time - proposed_time))) |>
    summarise(p = mean(value)) |>
    pull(p)

  information_gains <- base_uncertainties |>
    left_join(fail_uncertainties, by = "time") |>
    left_join(no_fail_uncertainties, by = "time") |>
    mutate(
      # probability-weighted uncertainty reduction; pmax(0, .) ignores any
      # (Monte Carlo noise) increases in variance
      weighted_reduction = pmax(0, (uncertainty_base - uncertainty_fail)) * p_fail +
                           pmax(0, (uncertainty_base - uncertainty_no_fail)) * (1 - p_fail)

    )

  # return the expected information gain (summed over the prediction grid)
  return(information_gains$weighted_reduction |> sum())
}
Code
import copy

def _as_list(values):
    """Return `values` as a plain Python list (numpy arrays via .tolist())."""
    return values.tolist() if hasattr(values, "tolist") else list(values)

def estimate_information_gain(proposed_time):
    """Expected information gain from one additional inspection at `proposed_time`.

    Simulates both possible outcomes (failure / no failure), re-fits the Stan
    model for each, and returns the probability-weighted reduction in
    posterior-predictive variance (summed over prediction times within a
    window around `proposed_time`).

    Relies on module globals: model_data, survival_model, post_pred,
    pred_params, large_num, process_mcmc_draws.
    """
    fail_data = copy.deepcopy(model_data)
    no_fail_data = copy.deepcopy(model_data)

    # Normalise the appendable fields to plain lists and record the extra
    # observation time shared by both scenarios.  (Replaces eight duplicated
    # `tolist()` conversion lines in the original.)
    for data in (fail_data, no_fail_data):
        for key in ("obs_time", "fail_lb", "fail_ub", "fail_status"):
            data[key] = _as_list(model_data[key])
        data["n_meas"] = model_data["n_meas"] + 1
        data["obs_time"].append(proposed_time)

    # case A: we observe a failure -> interval-censored in (t - 1.5, t]
    fail_data["fail_lb"].append(proposed_time - 1.5)
    fail_data["fail_ub"].append(proposed_time)
    fail_data["fail_status"].append(1)

    # case B: no failure -> right-censored at t (large_num stands in for Inf)
    no_fail_data["fail_lb"].append(proposed_time)
    no_fail_data["fail_ub"].append(large_num)
    no_fail_data["fail_status"].append(0)

    # re-fit the model under each hypothetical outcome
    fail_fit = survival_model.sample(
        data = fail_data,
        chains = 4,
        parallel_chains = multiprocessing.cpu_count(),
        seed = 231123,
        iter_warmup = 2000,
        iter_sampling = 2000
    )

    no_fail_fit = survival_model.sample(
        data = no_fail_data,
        chains = 4,
        parallel_chains = multiprocessing.cpu_count(),
        seed = 231123,
        iter_warmup = 2000,
        iter_sampling = 2000
    )

    # only compare uncertainties within this window of the proposed time
    window = 2.0

    base_uncertainties = (
        post_pred
        .filter(abs(pl.col("time") - proposed_time) <= window)
        .group_by("time")
        .agg(uncertainty_base=pl.col("value").var())
        .sort("time")
    )

    fail_post = (
        process_mcmc_draws(fail_fit, pred_params)
        .filter((pl.col("time") - proposed_time).abs() <= window)
        .group_by("time")
        .agg(pl.col("value").var().alias("uncertainty_fail"))
        .sort("time")
    )

    no_fail_post = (
        process_mcmc_draws(no_fail_fit, pred_params)
        .filter((pl.col("time") - proposed_time).abs() <= window)
        .group_by("time")
        .agg(pl.col("value").var().alias("uncertainty_no_fail"))
        .sort("time")
    )

    # predicted probability of failure at the grid point nearest proposed_time
    min_diff = (
        post_pred
        .select((pl.col("time") - proposed_time).abs().alias("diff"))
        .select(pl.col("diff").min())
        .item()
    )

    p_fail = (
        post_pred
        .filter((pl.col("time") - proposed_time).abs() == min_diff)
        .select(pl.col("value").mean().alias("p"))
        .item()
    )

    information_gains = (
        base_uncertainties
        .join(fail_post, on="time", how="left")
        .join(no_fail_post, on="time", how="left")
        .with_columns(
            # probability-weighted variance reduction; negative reductions
            # (Monte Carlo noise) are clamped to zero
            weighted_reduction=(
                pl.when(pl.col("uncertainty_base") - pl.col("uncertainty_fail") > 0)
                  .then(pl.col("uncertainty_base") - pl.col("uncertainty_fail"))
                  .otherwise(0) * p_fail +
                pl.when(pl.col("uncertainty_base") - pl.col("uncertainty_no_fail") > 0)
                  .then(pl.col("uncertainty_base") - pl.col("uncertainty_no_fail"))
                  .otherwise(0) * (1 - p_fail)
            )
        )
    )

    # Return the total information gain (sum over weighted_reduction)
    total_gain = information_gains.select(pl.col("weighted_reduction")).sum().item()
    return total_gain

expected information gain

experimental design

  • what do we want to achieve with data collection?
    • reduce uncertainty in predictions?
    • test a hypothesis?
    • support decision-making? (see “value of information analysis”)

break?